Conversation

@HosseinKaviani-H (Contributor)

Add Multi-Node Distributed Training Support for SLURM Clusters

Summary

This PR adds comprehensive multi-node distributed training support to Forge, enabling scalable training on SLURM-managed GPU clusters. Successfully tested on a 32-node cluster with 128 GPUs (4 GPUs per node) training Qwen3-32B.

Motivation

The existing Forge implementation only supported single-node multi-GPU training. This PR extends support to multi-node environments, which is essential for:

  • Training large models (30B+ parameters) that require 100+ GPUs
  • Leveraging institutional SLURM clusters
  • Scaling beyond single-node hardware limitations

Key Changes

1. Multi-Node LOCAL_RANK Fix (apps/sft/main.py)

  • Problem: The original code set LOCAL_RANK = RANK, which breaks on multi-node setups (with 4 GPUs per node, the first rank on the second node gets LOCAL_RANK=4 instead of 0)
  • Solution: Calculate LOCAL_RANK = RANK % gpus_per_node so each rank maps to the correct GPU on its own node (see the sketch after this list)
  • Impact: Critical for correct CUDA device mapping in multi-node environments
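
A minimal sketch of this calculation, assuming RANK is already set by the launcher and the per-node GPU count comes from SLURM_GPUS_PER_NODE or the YAML config; the exact plumbing in apps/sft/main.py may differ:

```python
# Illustrative sketch of the LOCAL_RANK fix; not the PR's verbatim code.
import os

rank = int(os.environ["RANK"])  # global rank across all nodes (0 .. world_size-1)
gpus_per_node = int(os.environ.get("SLURM_GPUS_PER_NODE", "4").split(":")[-1])  # "4" or "gpu:4"

# With 4 GPUs per node: ranks 0-3 become local ranks 0-3 on the first node,
# ranks 4-7 become local ranks 0-3 on the second node, and so on.
local_rank = rank % gpus_per_node
os.environ["LOCAL_RANK"] = str(local_rank)  # later consumed by torch.cuda.set_device(local_rank)
```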

2. SLURM Provisioner Integration (apps/sft/main.py)

  • Added provisioner initialization with launcher configuration support
  • Supports both actors and legacy processes config for backward compatibility
  • Proper shutdown sequence: cleanup → actor shutdown → provisioner shutdown
  • Try/except handling for a graceful exit on KeyboardInterrupt (lifecycle sketched below)
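
A rough lifecycle sketch of how these pieces fit together; the helpers below (init_provisioner, spawn_trainer) are placeholders for illustration, not Forge's actual API:

```python
# Placeholder names throughout; this sketches the ordering, not Forge's real API.
async def main(cfg):
    provisioner = await init_provisioner(cfg.provisioner)  # hypothetical: start the SLURM launcher
    trainer = await spawn_trainer(cfg)                      # hypothetical: spawn trainer actors
    try:
        await trainer.train()
    except KeyboardInterrupt:
        pass  # fall through to the same orderly teardown on Ctrl-C
    finally:
        # Shutdown order from the PR: cleanup -> actor shutdown -> provisioner shutdown,
        # so actors release resources before the provisioner that owns them goes away.
        await trainer.cleanup()
        await trainer.stop()
        await provisioner.shutdown()
```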

3. Smart SLURM Resource Detection (src/forge/controller/launcher.py)

  • Auto-detection from SLURM environment variables (SLURM_CPUS_ON_NODE, SLURM_MEM_PER_NODE, SLURM_GPUS_PER_NODE)
  • Fallback to scontrol show node when env vars unavailable
  • Flexible configuration: Resources can be specified in YAML or auto-detected
  • Handles different SLURM versions and GPU format variations ("4" vs "gpu:4"); a sketch of this logic follows below
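
A sketch of the detection logic under stated assumptions: the env-var names and the scontrol fallback come from the PR description, while the field parsing and helper names here are illustrative, not the exact implementation:

```python
# Illustrative SLURM resource detection; parsing details are assumptions.
import os
import re
import subprocess

def _gpu_count(value: str) -> int:
    # SLURM may report "4" or "gpu:4" depending on version/configuration.
    digits = re.sub(r"\D", "", value.rsplit(":", 1)[-1])
    return int(digits) if digits else 0

def detect_slurm_resources() -> dict:
    cpus = os.environ.get("SLURM_CPUS_ON_NODE")
    mem_mb = os.environ.get("SLURM_MEM_PER_NODE")
    gpus = os.environ.get("SLURM_GPUS_PER_NODE")

    if not (cpus and mem_mb and gpus):
        # Fallback: ask the scheduler about this node directly.
        node = os.environ.get("SLURMD_NODENAME", "")
        out = subprocess.run(["scontrol", "show", "node", node],
                             capture_output=True, text=True).stdout
        cpus = cpus or (re.search(r"CPUTot=(\d+)", out) or [None, "0"])[1]
        mem_mb = mem_mb or (re.search(r"RealMemory=(\d+)", out) or [None, "0"])[1]
        gpus = gpus or (re.search(r"gpu:(\d+)", out) or [None, "0"])[1]

    return {
        "cpu": int(cpus),
        "memory_mb": int(mem_mb),
        "gpus_per_node": _gpu_count(str(gpus)),
    }
```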

4. Configurable Data Loading (apps/sft/main.py)

  • Added num_shards_per_rank parameter (default: 64 for large datasets, 8 for small)
  • Added num_dataloader_workers parameter (default: 0 to avoid CUDA fork issues)
  • Better I/O parallelism control for different dataset sizes (see the sketch below)
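
A sketch, under assumptions, of how these two knobs would typically be consumed; the actual wiring in apps/sft/main.py may differ:

```python
# Illustrative only; parameter plumbing is an assumption, not the PR's exact code.
from torch.utils.data import DataLoader, Dataset

def build_dataloader(
    dataset: Dataset,
    batch_size: int = 8,
    num_shards_per_rank: int = 64,    # 64 for large datasets, 8 for small ones
    num_dataloader_workers: int = 0,  # 0 keeps loading in the main process
) -> DataLoader:
    # num_shards_per_rank would be forwarded to the dataset's sharding setup so each
    # rank streams from enough shards to keep I/O parallel; that wiring is
    # dataset-specific, so it is shown here only as a parameter.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_dataloader_workers)
```

Keeping num_dataloader_workers at 0 is the conservative default because forked worker processes cannot safely reuse an already-initialized CUDA context.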

Testing

  • ✅ Tested on 32-node SLURM cluster (128 total GPUs)
  • ✅ Successfully trains Qwen3-32B (32B parameters)
  • ✅ Verified NCCL communication across nodes (smoke-test sketch after this list)
  • ✅ Confirmed checkpoint saving/loading
  • ✅ WandB metric logging functional
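
For reference, the cross-node NCCL check can be reproduced with a standard torch.distributed all_reduce smoke test (not part of this PR), assuming torchrun-style environment variables are set for each process:

```python
# Minimal NCCL smoke test; run one process per GPU with RANK/LOCAL_RANK/WORLD_SIZE/
# MASTER_ADDR/MASTER_PORT set (e.g. by torchrun or the provisioner).
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

x = torch.ones(1, device="cuda")
dist.all_reduce(x)  # sums the tensor across all ranks
assert x.item() == dist.get_world_size()  # passes only if every node participated
dist.destroy_process_group()
```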

Files Changed

  • apps/sft/main.py - Multi-node support, provisioner integration, data config
  • apps/sft/qwen3_32b.yaml - New optimized 128-GPU config (renamed from qwen3_32b_multinode.yaml)
  • qwen3_32b.yaml - GRPO config for Qwen3-32B
  • src/forge/controller/launcher.py - SLURM resource auto-detection
  • src/forge/types.py - Type definitions for launcher config

Usage Example

YAML Config:

```yaml
provisioner:
  launcher: slurm
  cpu: 128
  memory_mb: 1655502
  gpus_per_node: 4

actors:
  trainer:
    procs: 4
    hosts: 32
    with_gpus: true
```

Run Training:

```bash
python -m apps.sft.main --config apps/sft/qwen3_32b.yaml
```

Hossein Kavianihamedani added 2 commits November 5, 2025 08:33
- Add multi-node training support in main.py with proper LOCAL_RANK calculation
- Add qwen3_32b.yaml config optimized for 32-node, 128 GPU training
- Add qwen3_32b.yaml config for GRPO training
- Update launcher.py with SLURM resource auto-detection from environment
- Update types.py with necessary type definitions

Key features:
- Proper multi-node LOCAL_RANK: rank % gpus_per_node (fixes cross-node issues)
- Provisioner support for SLURM multi-node orchestration
- SLURM resource inference from environment variables and scontrol
- Configurable data loading: num_shards_per_rank, num_dataloader_workers
- Optimized training config: TP=4, FSDP=32, selective AC every 2 layers
- Async checkpointing enabled for non-blocking saves
- Backward compatibility with legacy 'processes' config

Optimizations applied:
- Activation checkpointing: selective with layer frequency 2 (2-3x faster)
- Async checkpointing: non-blocking background saves
- Batch size 8 with gradient accumulation 2 for convergence
- 64 shards per rank for optimal I/O parallelism
- SLURM_SWITCHES=2 for network locality (18 nodes/block topology)

Tested on:
- 32 nodes × 4 GPUs = 128 total GPUs
- Ethernet network with SLURM block topology
- Qwen3-32B model (32B parameters)
- 5,000 training steps with WandB logging
@meta-cla bot added the CLA Signed label on Nov 5, 2025

@allenwang28 (Contributor) left a comment:

this mostly looks fine to me, but I would like @daniellepintz to take a look as well over the SFT pieces!


```python
        super().__init__(job_config)

    def _init_dist(self):
```

Contributor:

if we remove the _init_dist altogether would this still work? I added this line in get_proc_mesh later, so this should not be needed anymore. Could you please try it out?

Contributor Author (@HosseinKaviani-H):

I need to test this. Right now, I pass the local rank and NCCL variables within the env there. Will keep you posted.

Contributor:

Maybe we should split out the SLURM specific PR from the SFT PR?

Contributor Author (@HosseinKaviani-H):

Hmm that's a reasonable approach. I'll think of how to separate them and raise a new one
